/*
 * Written by Dawid Kurzyniec and released to the public domain, as explained
 * at http://creativecommons.org/licenses/publicdomain
 */

package edu.emory.mathcs.util.net.tunnel;

import java.io.*;
import java.net.*;

/**
 * A client-side socket whose traffic is carried over an underlying "tunnel"
 * {@link Socket}. Connecting writes a short handshake (protocol magic,
 * protocol version, target port) to the tunnel and waits for an
 * acknowledgement. Once acknowledged, the tunnel's own streams are handed out
 * as this socket's streams, and most socket options are delegated directly
 * to the tunnel.
 *
 * @see TunnelServerSocket
 *
 * @author Dawid Kurzyniec
 * @version 1.0
 */
public class TunnelSocket extends Socket {

    /** Port requested in the handshake; reported by {@link #getRemoteSocketAddress}. */
    int remotePort;
    /** The underlying socket that carries all traffic of this socket. */
    Socket tunnel;
    /** Read timeout (ms) applied to the tunnel once connected; 0 means infinite. */
    volatile int soTimeout = 0;

    /** Set by {@link #close}; {@link #isClosed} reports the tunnel's state. */
    volatile boolean closed;
    volatile boolean connected;
    volatile boolean isConnecting;

    /**
     * Creates a socket over the given (unconnected) tunnel socket and connects
     * it to the specified tunnel address.
     *
     * @throws IOException if I/O error occurs
     */
    public TunnelSocket(Socket tunnel, TunnelSocketAddress addr) throws IOException {
        this(tunnel);
        connect(addr);
    }

    /**
     * Creates a socket over a fresh tunnel socket and connects it to the
     * specified tunnel address.
     *
     * @throws IOException if I/O error occurs
     */
    public TunnelSocket(TunnelSocketAddress addr) throws IOException {
        this();
        connect(addr);
    }

    /**
     * Creates a socket over a fresh tunnel socket, connecting to the tunnel
     * endpoint {@code addr} and requesting the tunneled {@code port}.
     *
     * @throws IOException if I/O error occurs
     */
    public TunnelSocket(InetSocketAddress addr, int port) throws IOException {
        this();
        connect(new TunnelSocketAddress(addr, port));
    }

    /**
     * Creates an unconnected socket that will use the given tunnel socket.
     *
     * @throws NullPointerException if {@code tunnel} is null
     * @throws IOException if I/O error occurs
     */
    public TunnelSocket(Socket tunnel) throws IOException {
        if (tunnel == null) throw new NullPointerException();
        this.tunnel = tunnel;
    }

    /**
     * Creates an unconnected socket backed by a new, unconnected tunnel socket.
     *
     * @throws IOException if I/O error occurs
     */
    public TunnelSocket() throws IOException {
        this.tunnel = new Socket();
    }

    /**
     * Performs the tunnel handshake for {@code port} over an already-connected
     * tunnel socket, with no handshake timeout.
     *
     * @throws IOException if I/O error occurs or the handshake is refused
     */
    public TunnelSocket(Socket tunnel, int port) throws IOException {
        this(tunnel, port, 0);
    }

    /**
     * Performs the tunnel handshake for {@code port} over an already-connected
     * tunnel socket.
     *
     * @param timeout handshake timeout in milliseconds; 0 means infinite
     * @throws IOException if I/O error occurs or the handshake is refused
     */
    public TunnelSocket(Socket tunnel, int port, int timeout) throws IOException {
        this(tunnel);
        connectOverConnected(tunnel, port, timeout);
    }

    /**
     * Connects the tunnel socket to the tunnel endpoint of {@code endpoint} and
     * performs the handshake requesting {@code endpoint}'s port. The timeout
     * bounds the whole operation (TCP connect plus handshake). On I/O failure
     * this socket is closed.
     *
     * @param endpoint must be a {@code TunnelSocketAddress}
     * @param timeout overall timeout in milliseconds; 0 means infinite
     * @throws IllegalArgumentException if endpoint has the wrong type or timeout is negative
     * @throws SocketException if this socket is closed or already connected
     * @throws SocketTimeoutException if the timeout expires
     * @throws IOException if another connect is in progress, or on I/O error
     */
    public void connect(SocketAddress endpoint, int timeout) throws IOException {
        if (!(endpoint instanceof TunnelSocketAddress)) {
            throw new IllegalArgumentException("connect: The address must be " +
                                               "an instance of TunnelSocketAddress");
        }
        if (timeout < 0) {
            throw new IllegalArgumentException("connect: timeout can't be negative");
        }

        ensureNotClosed();
        TunnelSocketAddress addr = (TunnelSocketAddress)endpoint;
        SocketAddress tunnelAddr = addr.getTunnelAddress();
        int port = addr.getPort();

        // TODO: security check?
        beginConnect();
        try {
            long deadline = (timeout > 0) ? System.currentTimeMillis() + timeout : 0;
            // Bound the TCP connect itself, not just the handshake.
            tunnel.connect(tunnelAddr, timeout);
            // The handshake only gets whatever time the TCP connect left over.
            handShake(tunnel, port, remainingTimeout(deadline));
            finishConnect(tunnel, port);
        }
        catch (IOException e) {
            closeAfterFailure();
            throw e;
        }
        finally {
            isConnecting = false;
        }
    }

    /**
     * Performs the handshake for {@code port} over an already-connected
     * {@code tunnel}. On I/O failure this socket is closed.
     *
     * @param timeout handshake timeout in milliseconds; 0 means infinite
     */
    void connectOverConnected(Socket tunnel, int port, int timeout) throws IOException {
        if (timeout < 0) {
            throw new IllegalArgumentException("connect: timeout can't be negative");
        }
        if (!tunnel.isConnected()) {
            throw new IllegalArgumentException("Tunnel must be connected");
        }

        ensureNotClosed();

        // TODO: security check?
        beginConnect();
        try {
            handShake(tunnel, port, timeout);
            finishConnect(tunnel, port);
        }
        catch (IOException e) {
            closeAfterFailure();
            throw e;
        }
        finally {
            isConnecting = false;
        }
    }

    /**
     * Marks a connect attempt as started, rejecting concurrent or repeated
     * connects. Rejecting here, before the caller's try block, ensures a live
     * connection is not closed by a second connect call.
     */
    private synchronized void beginConnect() throws IOException {
        if (isConnecting) {
            throw new IOException("Connect already in progress");
        }
        if (connected) {
            throw new SocketException("already connected");
        }
        isConnecting = true;
    }

    /**
     * Converts an absolute deadline into the remaining timeout.
     *
     * @param deadline absolute time in ms, or 0 for "no deadline"
     * @return remaining milliseconds (at least 1), or 0 if there is no deadline
     * @throws SocketTimeoutException if the deadline has already passed
     */
    private static int remainingTimeout(long deadline) throws SocketTimeoutException {
        if (deadline == 0) return 0;
        long left = deadline - System.currentTimeMillis();
        if (left <= 0) throw new SocketTimeoutException("connect timed out");
        // left never exceeds the original int timeout, so the cast is safe.
        return (int)left;
    }

    /**
     * Writes the handshake request (magic, version, port) and reads the
     * server's single int response.
     *
     * @param timeout read timeout for the response in ms; if 0, the tunnel's
     *        current read timeout is left unchanged
     * @throws ConnectException if the server refuses the connection
     */
    private void handShake(Socket tunnel, int port, int timeout) throws IOException {
        if (timeout > 0) {
            tunnel.setSoTimeout(timeout);
        }

        // The wrappers are deliberately not closed: closing a socket stream
        // closes the socket. DataInputStream does not buffer, so readInt()
        // consumes exactly the 4 response bytes and no payload.
        DataOutputStream out = new DataOutputStream(tunnel.getOutputStream());
        out.writeInt(TunnelServerSocket.PROTOCOL_MAGIC);
        out.writeInt(TunnelServerSocket.PROTOCOL_VERSION_1);
        out.writeInt(port);
        out.flush();

        DataInputStream in = new DataInputStream(tunnel.getInputStream());
        int response = in.readInt();
        switch (response) {
            case TunnelServerSocket.ACK:
                return;
            case TunnelServerSocket.NACK_NO_SERVER:
                throw new ConnectException("Connection refused");
            case TunnelServerSocket.NACK_QUEUE_FULL:
                throw new ConnectException("Connection refused " +
                                           "(server busy, try again later)");
            default:
                throw new ConnectException("Connection refused " +
                                           "(unexpected handshake response: " +
                                           response + ")");
        }
    }

    /**
     * Restores the user-visible read timeout on the tunnel and marks this
     * socket connected to {@code port}. Both steps happen under the same lock
     * as {@link #setSoTimeout}, so a concurrent timeout change is not lost.
     */
    private synchronized void finishConnect(Socket tunnel, int port) throws SocketException {
        int t = soTimeout;
        tunnel.setSoTimeout(t >= 0 ? t : 0);
        setConnected(port);
    }

    /**
     * Closes this socket after a failed connect. A failure during close is
     * ignored so that the original connect exception is the one propagated.
     */
    private void closeAfterFailure() {
        try {
            close();
        }
        catch (IOException ignored) {
            // Best effort: the connect failure being rethrown is more informative.
        }
    }

    /**
     * Not supported.
     *
     * @param bindpoint ignored
     * @throws SocketException if this socket is closed
     * @throws UnsupportedOperationException always, if the socket is not closed
     */
    public void bind(SocketAddress bindpoint) throws IOException {
        ensureNotClosed();
        throw new UnsupportedOperationException();
    }

    /** Returns the underlying tunnel socket. */
    protected Socket getTunnel() { return tunnel; }

    /** Returns the remote address of the tunnel socket. */
    public InetAddress getInetAddress() {
        return tunnel.getInetAddress();
    }

    /** Returns the local address of the tunnel socket. */
    public InetAddress getLocalAddress() {
        return tunnel.getLocalAddress();
    }

    /**
     * Returns the remote port of the tunnel socket (not the tunneled port),
     * or 0 if this socket is not connected.
     */
    public int getPort() {
        if (!isConnected()) return 0;
        return tunnel.getPort();
    }

    /** Returns the local port of the tunnel socket. */
    public int getLocalPort() {
        if (!isBound()) return -1;
        return tunnel.getLocalPort();
    }

    /**
     * Returns server endpoint information as a {@link TunnelSocketAddress}:
     * the tunnel's remote address together with the tunneled port.
     *
     * @return server endpoint information, or null if not connected
     */
    public SocketAddress getRemoteSocketAddress() {
        if (!isConnected()) return null;
        return new TunnelSocketAddress(tunnel.getRemoteSocketAddress(), remotePort);
    }

    /**
     * Returns the local socket address of the tunnel socket.
     *
     * @return client endpoint information
     */
    public SocketAddress getLocalSocketAddress() {
        if (!isBound()) return null;
        return tunnel.getLocalSocketAddress();
    }

    /**
     * Returns the tunnel's input stream.
     *
     * @throws SocketException if closed, not connected, or input is shut down
     */
    public synchronized InputStream getInputStream() throws IOException {
        ensureNotClosed();
        if (!isConnected()) throw new SocketException("Socket is not connected");
        if (isInputShutdown()) throw new SocketException("Socket input is shutdown");
        return tunnel.getInputStream();
    }

    /**
     * Returns the tunnel's output stream.
     *
     * @throws SocketException if closed, not connected, or output is shut down
     */
    public synchronized OutputStream getOutputStream() throws IOException {
        ensureNotClosed();
        if (!isConnected()) throw new SocketException("Socket is not connected");
        if (isOutputShutdown()) throw new SocketException("Socket output is shutdown");
        return tunnel.getOutputStream();
    }

    /** Delegates to the tunnel socket. */
    public void setTcpNoDelay(boolean on) throws SocketException {
        tunnel.setTcpNoDelay(on);
    }

    /** Delegates to the tunnel socket. */
    public boolean getTcpNoDelay() throws SocketException {
        return tunnel.getTcpNoDelay();
    }

    /** Delegates to the tunnel socket. */
    public void setSoLinger(boolean on, int linger) throws SocketException {
        tunnel.setSoLinger(on, linger);
    }

    /** Delegates to the tunnel socket. */
    public int getSoLinger() throws SocketException {
        return tunnel.getSoLinger();
    }

    /** Delegates to the tunnel socket. */
    public void sendUrgentData (int data) throws IOException  {
        tunnel.sendUrgentData(data);
    }

    /** Delegates to the tunnel socket. */
    public void setOOBInline(boolean on) throws SocketException {
        tunnel.setOOBInline(on);
    }

    /** Delegates to the tunnel socket. */
    public boolean getOOBInline() throws SocketException {
        return tunnel.getOOBInline();
    }

    /**
     * Sets the read timeout. If already connected, the timeout is applied to
     * the tunnel immediately (its streams are the ones handed out). Otherwise
     * it is applied when the connect completes.
     *
     * @param timeout timeout in milliseconds; 0 means infinite
     * @throws SocketException if this socket is closed
     * @throws IllegalArgumentException if timeout is negative
     */
    public synchronized void setSoTimeout(int timeout) throws SocketException {
        ensureNotClosed();
        if (timeout < 0)
            throw new IllegalArgumentException("timeout can't be negative; got: " + timeout);
        soTimeout = timeout;
        if (connected) {
            tunnel.setSoTimeout(timeout);
        }
    }

    /**
     * Returns the read timeout set via {@link #setSoTimeout}.
     *
     * @throws SocketException if this socket is closed
     */
    public synchronized int getSoTimeout() throws SocketException {
        ensureNotClosed();
        return soTimeout;
    }

    /** Delegates to the tunnel socket. */
    public void setSendBufferSize(int size) throws SocketException {
        tunnel.setSendBufferSize(size);
    }

    /** Delegates to the tunnel socket. */
    public int getSendBufferSize() throws SocketException {
        return tunnel.getSendBufferSize();
    }

    /** Delegates to the tunnel socket. */
    public void setReceiveBufferSize(int size) throws SocketException{
        tunnel.setReceiveBufferSize(size);
    }

    /** Delegates to the tunnel socket. */
    public int getReceiveBufferSize() throws SocketException {
        return tunnel.getReceiveBufferSize();
    }

    /** Delegates to the tunnel socket. */
    public void setKeepAlive(boolean on) throws SocketException {
        tunnel.setKeepAlive(on);
    }

    /** Delegates to the tunnel socket. */
    public boolean getKeepAlive() throws SocketException {
        return tunnel.getKeepAlive();
    }

    /** Delegates to the tunnel socket. */
    public void setTrafficClass(int tc) throws SocketException {
        tunnel.setTrafficClass(tc);
    }

    /** Delegates to the tunnel socket. */
    public int getTrafficClass() throws SocketException {
        return tunnel.getTrafficClass();
    }

    /** Delegates to the tunnel socket. */
    public void setReuseAddress(boolean on) throws SocketException {
        tunnel.setReuseAddress(on);
    }

    /** Delegates to the tunnel socket. */
    public boolean getReuseAddress() throws SocketException {
        return tunnel.getReuseAddress();
    }

    /** Closes this socket by closing the tunnel socket. */
    public void close() throws IOException {
        closed = true;
        tunnel.close();
    }

    /** Shuts down input on the tunnel socket. */
    public void shutdownInput() throws IOException {
        tunnel.shutdownInput();
    }

    /** Shuts down output on the tunnel socket. */
    public void shutdownOutput() throws IOException {
        tunnel.shutdownOutput();
    }

    public String toString() {
        if (isConnected()) {
            return "Socket[remote=" + getRemoteSocketAddress() +
                   ",local=" + getLocalSocketAddress() + "]";
        }
        else {
            return "Socket[unconnected]";
        }
    }


    /** Records the tunneled port and marks this socket connected. */
    synchronized void setConnected(int remotePort) {
        this.remotePort = remotePort;
        this.connected = true;
    }

    public boolean isConnected() {
        return connected;
    }

    /** Always true: binding is delegated to the tunnel socket. */
    public boolean isBound() {
        return true;
    }

    /** Reports whether the tunnel socket is closed. */
    public boolean isClosed() {
        return tunnel.isClosed();
    }

    public boolean isInputShutdown() {
        return tunnel.isInputShutdown();
    }

    public boolean isOutputShutdown() {
        return tunnel.isOutputShutdown();
    }

    private void ensureNotClosed() throws SocketException {
        if (isClosed()) throw new SocketException("Socket is closed");
    }
}
